{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base Dependencies\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "j_ = os.path.join\n",
    "\n",
    "# LinAlg / Stats / Plotting Dependencies\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Scikit-Learn Imports\n",
    "import sklearn\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import cross_val_score, StratifiedKFold\n",
    "\n",
    "#Torch Imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data.dataset import Dataset\n",
    "torch.multiprocessing.set_sharing_strategy('file_system')\n",
    "\n",
    "# Utils\n",
    "from slide_extraction_utils import create_slide_embeddings\n",
    "from slide_evaluation_utils import get_knn_classification_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Saving Each WSI Bag as it's Mean Instance-Level Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting resnet50mean embedddings for tcga_brca\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 52891.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting resnet50mean embedddings for tcga_kidney\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 116185.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting resnet50mean embedddings for tcga_lung\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 63358.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting vit16mean embedddings for tcga_brca\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 47180.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting vit16mean embedddings for tcga_kidney\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 66052.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting vit16mean embedddings for tcga_lung\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 10/10 [00:00<00:00, 248183.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting vit256mean embedddings for tcga_brca\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 115545.56it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting vit256mean embedddings for tcga_kidney\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 161942.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting vit256mean embedddings for tcga_lung\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:00<00:00, 86659.17it/s]\n"
     ]
    }
   ],
   "source": [
    "r\"\"\"\n",
    "Script for saving mean WSI features for each feature type in each task\n",
    "\"\"\"\n",
    "\n",
    "dataroot = './embeddings_slide_lib/'\n",
    "saveroot = './embeddings_slide_lib/knn-subtyping/'\n",
    "os.makedirs(saveroot, exist_ok=True)\n",
    "\n",
    "for enc_name in ['resnet50mean', 'vit16mean', 'vit256mean']:\n",
    "    for study in ['tcga_brca', 'tcga_kidney', 'tcga_lung']:\n",
    "        print(f'Extracting {enc_name} embedddings for {study}')\n",
    "        create_slide_embeddings(dataroot=dataroot, saveroot=saveroot,\n",
    "                                enc_name=enc_name, study=study)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 10-Fold CV Evaluation of Mean WSI Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████| 3/3 [00:03<00:00,  1.21s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllllllll}\n",
      "\\toprule\n",
      "{} &  Pretrain &       Arch &             0.25 &              1.0 &             0.25 &              1.0 &             0.25 &              1.0 \\\\\n",
      "\\midrule\n",
      "{} &  ImageNet &  ResNet-50 &  0.638 +/- 0.089 &  0.667 +/- 0.070 &  0.696 +/- 0.055 &  0.794 +/- 0.035 &  0.862 +/- 0.030 &  0.951 +/- 0.016 \\\\\n",
      "{} &      DINO &     ViT-16 &  0.605 +/- 0.092 &  0.725 +/- 0.083 &  0.622 +/- 0.067 &  0.742 +/- 0.045 &  0.848 +/- 0.032 &  0.899 +/- 0.027 \\\\\n",
      "{} &      DINO &    ViT-256 &  0.682 +/- 0.055 &  0.775 +/- 0.042 &  0.773 +/- 0.048 &  0.889 +/- 0.027 &  0.916 +/- 0.022 &  0.974 +/- 0.016 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Pretrain</th>\n",
       "      <th>Arch</th>\n",
       "      <th>0.25</th>\n",
       "      <th>1.0</th>\n",
       "      <th>0.25</th>\n",
       "      <th>1.0</th>\n",
       "      <th>0.25</th>\n",
       "      <th>1.0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>ImageNet</td>\n",
       "      <td>ResNet-50</td>\n",
       "      <td>0.638 +/- 0.089</td>\n",
       "      <td>0.667 +/- 0.070</td>\n",
       "      <td>0.696 +/- 0.055</td>\n",
       "      <td>0.794 +/- 0.035</td>\n",
       "      <td>0.862 +/- 0.030</td>\n",
       "      <td>0.951 +/- 0.016</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>DINO</td>\n",
       "      <td>ViT-16</td>\n",
       "      <td>0.605 +/- 0.092</td>\n",
       "      <td>0.725 +/- 0.083</td>\n",
       "      <td>0.622 +/- 0.067</td>\n",
       "      <td>0.742 +/- 0.045</td>\n",
       "      <td>0.848 +/- 0.032</td>\n",
       "      <td>0.899 +/- 0.027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>DINO</td>\n",
       "      <td>ViT-256</td>\n",
       "      <td>0.682 +/- 0.055</td>\n",
       "      <td>0.775 +/- 0.042</td>\n",
       "      <td>0.773 +/- 0.048</td>\n",
       "      <td>0.889 +/- 0.027</td>\n",
       "      <td>0.916 +/- 0.022</td>\n",
       "      <td>0.974 +/- 0.016</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Pretrain       Arch             0.25              1.0             0.25  \\\n",
       "  ImageNet  ResNet-50  0.638 +/- 0.089  0.667 +/- 0.070  0.696 +/- 0.055   \n",
       "      DINO     ViT-16  0.605 +/- 0.092  0.725 +/- 0.083  0.622 +/- 0.067   \n",
       "      DINO    ViT-256  0.682 +/- 0.055  0.775 +/- 0.042  0.773 +/- 0.048   \n",
       "\n",
       "              1.0             0.25              1.0  \n",
       "  0.794 +/- 0.035  0.862 +/- 0.030  0.951 +/- 0.016  \n",
       "  0.742 +/- 0.045  0.848 +/- 0.032  0.899 +/- 0.027  \n",
       "  0.889 +/- 0.027  0.916 +/- 0.022  0.974 +/- 0.016  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r\"\"\"\n",
    "Script for runnign 10-fold CV for each feature type for each TCGA study.\n",
    "\"\"\"\n",
    "    \n",
    "results_all = []\n",
    "dataroot = './embeddings_slide_lib/knn-subtyping/'\n",
    "\n",
    "for enc_name in tqdm(['resnet50mean', 'vit16mean', 'vit256mean']):\n",
    "    results_row = []\n",
    "    for study in ['tcga_brca', 'tcga_lung', 'tcga_kidney']:\n",
    "        for prop in [0.25, 1.0]:\n",
    "            aucs = get_knn_classification_results(dataroot, study, enc_name, prop)\n",
    "            aucs = '%0.3f +/- %0.3f' % (aucs.mean(), aucs.std())\n",
    "            results_row.append([aucs])\n",
    "    \n",
    "    results_all.append(pd.DataFrame(results_row).T)\n",
    "    \n",
    "results_df = pd.concat(results_all)\n",
    "results_df.index = ['resnet50mean', 'vit16mean', 'vit256mean']\n",
    "results_df.columns = [0.25, 1.0, 0.25, 1.0, 0.25, 1.0]\n",
    "results_df.index = ['', '', '']\n",
    "results_df.insert(0, 'Pretrain', ['ImageNet', 'DINO', 'DINO'])\n",
    "results_df.insert(1, 'Arch', ['ResNet-50','ViT-16', 'ViT-256'])\n",
    "print(results_df.to_latex())\n",
    "results_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1st Fold Evaluation of Mean WSI Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [00:01<00:00,  2.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\begin{tabular}{lllllllll}\n",
      "\\toprule\n",
      "{} &  Pretrain &       Arch &   0.25 &    1.0 &   0.25 &    1.0 &   0.25 &    1.0 \\\\\n",
      "\\midrule\n",
      "{} &  ImageNet &  ResNet-50 &  0.706 &  0.566 &  0.681 &  0.789 &  0.867 &  0.947 \\\\\n",
      "{} &      DINO &     ViT-16 &  0.719 &  0.833 &  0.586 &  0.668 &  0.855 &  0.892 \\\\\n",
      "{} &      DINO &    ViT-256 &  0.711 &  0.808 &  0.728 &  0.947 &  0.929 &  0.979 \\\\\n",
      "\\bottomrule\n",
      "\\end{tabular}\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Pretrain</th>\n",
       "      <th>Arch</th>\n",
       "      <th>0.25</th>\n",
       "      <th>1.0</th>\n",
       "      <th>0.25</th>\n",
       "      <th>1.0</th>\n",
       "      <th>0.25</th>\n",
       "      <th>1.0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>ImageNet</td>\n",
       "      <td>ResNet-50</td>\n",
       "      <td>0.706</td>\n",
       "      <td>0.566</td>\n",
       "      <td>0.681</td>\n",
       "      <td>0.789</td>\n",
       "      <td>0.867</td>\n",
       "      <td>0.947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>DINO</td>\n",
       "      <td>ViT-16</td>\n",
       "      <td>0.719</td>\n",
       "      <td>0.833</td>\n",
       "      <td>0.586</td>\n",
       "      <td>0.668</td>\n",
       "      <td>0.855</td>\n",
       "      <td>0.892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th></th>\n",
       "      <td>DINO</td>\n",
       "      <td>ViT-256</td>\n",
       "      <td>0.711</td>\n",
       "      <td>0.808</td>\n",
       "      <td>0.728</td>\n",
       "      <td>0.947</td>\n",
       "      <td>0.929</td>\n",
       "      <td>0.979</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  Pretrain       Arch   0.25    1.0   0.25    1.0   0.25    1.0\n",
       "  ImageNet  ResNet-50  0.706  0.566  0.681  0.789  0.867  0.947\n",
       "      DINO     ViT-16  0.719  0.833  0.586  0.668  0.855  0.892\n",
       "      DINO    ViT-256  0.711  0.808  0.728  0.947  0.929  0.979"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r\"\"\"\n",
    "Script for running single-fold CV for each feature type for each TCGA study.\n",
    "\"\"\"\n",
    "    \n",
    "results_all = []\n",
    "dataroot = './embeddings_slide_lib/knn-subtyping/'\n",
    "\n",
    "for enc_name in tqdm(['resnet50mean', 'vit16mean', 'vit256mean']):\n",
    "    results_row = []\n",
    "    for study in ['tcga_brca', 'tcga_lung', 'tcga_kidney']:\n",
    "        for prop in [0.25, 1.0]:\n",
    "            aucs = get_knn_classification_results(dataroot, study, enc_name, prop)\n",
    "            aucs = '%0.3f' % (aucs.iloc[0][0])\n",
    "            results_row.append([aucs])\n",
    "    \n",
    "    results_all.append(pd.DataFrame(results_row).T)\n",
    "    \n",
    "results_df = pd.concat(results_all)\n",
    "results_df.index = ['resnet50mean', 'vit16mean', 'vit256mean']\n",
    "results_df.columns = [0.25, 1.0, 0.25, 1.0, 0.25, 1.0]\n",
    "results_df.index = ['', '', '']\n",
    "results_df.insert(0, 'Pretrain', ['ImageNet', 'DINO', 'DINO'])\n",
    "results_df.insert(1, 'Arch', ['ResNet-50','ViT-16', 'ViT-256'])\n",
    "print(results_df.to_latex())\n",
    "results_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Faster Sanity-Check that the Results are Correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 937/937 [00:01<00:00, 878.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tcga_brca 0.7738300898746104\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 905/905 [00:01<00:00, 663.50it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tcga_kidney 0.9739114652064474\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████| 958/958 [00:01<00:00, 754.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tcga_lung 0.8935382487870264\n"
     ]
    }
   ],
   "source": [
    "r\"\"\"\n",
    "Script for running single-fold CV for each feature type for each TCGA study.\n",
    "\"\"\"\n",
    "\n",
    "dataroot = './embeddings_slide_lib/vit256mean_tcga_slide_embeddings/'\n",
    "available_vit256_features = os.listdir(dataroot)\n",
    "\n",
    "for study in ['tcga_brca', 'tcga_kidney', 'tcga_lung']:\n",
    "    path2csv = '../Weakly-Supervised-Subtyping/dataset_csv/'\n",
    "    dataset = pd.read_csv(j_(path2csv, f'{study}_subset.csv.zip'), index_col=2)\n",
    "    dataset.index = dataset.index.str[:-4]\n",
    "    embeddings_all, labels_all = [], []\n",
    "    slide_ids = []\n",
    "    \n",
    "    if study == 'tcga_brca':\n",
    "        label_dict={'IDC':0, 'ILC':1}\n",
    "    elif study == 'tcga_kidney':\n",
    "        label_dict={'CCRCC':0, 'PRCC':1, 'CHRCC': 2}\n",
    "    elif study == 'tcga_lung':\n",
    "        label_dict={'LUSC':1, 'LUAD':0}\n",
    "                          \n",
    "\n",
    "    for slide_id in tqdm(dataset.index):\n",
    "        pt_fname, label = slide_id+'.pt', dataset.loc[slide_id]['oncotree_code']\n",
    "        if (pt_fname in available_vit256_features) and (label in label_dict.keys()):        \n",
    "            vit256_features = torch.load(os.path.join(dataroot, pt_fname)).mean(axis=0)\n",
    "            embeddings_all.append(vit256_features)\n",
    "            labels_all.append(label_dict[label])\n",
    "            slide_ids.append(slide_id)\n",
    "\n",
    "    embeddings_all = torch.stack(embeddings_all).numpy()\n",
    "    labels_all = np.array(labels_all)             \n",
    "                          \n",
    "    clf = KNeighborsClassifier()\n",
    "    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)\n",
    "    \n",
    "    if len(label_dict.keys()) > 2:\n",
    "        scores = cross_val_score(clf, embeddings_all, labels_all, cv=skf, scoring='roc_auc_ovr')\n",
    "    else:\n",
    "        scores = cross_val_score(clf, embeddings_all, labels_all, cv=skf, scoring='roc_auc')\n",
    "        \n",
    "    print(study, scores.mean())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
